package ru.demax.rhythmerr.audio.recognition.nn;

import com.google.firebase.remoteconfig.FirebaseRemoteConfig;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Locale;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.exceptions.NeurophException;
import org.neuroph.core.learning.error.ErrorFunction;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;

/* loaded from: classes2.dex */
/**
 * Wraps a Neuroph {@link MultiLayerPerceptron} built from a serializable
 * {@link NeuralNetworkConfig}, providing inference ({@link #calc}), training ({@link #train})
 * and Java-serialization based persistence ({@link #save}, {@link #load}).
 */
/* loaded from: classes2.dex */
public class SoundRecognitionNeuralNetwork {
    /** File written by {@link #save()}; resolved relative to the process working directory. */
    private static final String DEFAULT_MODEL_FILE_PATH = "SoundEventRecognition/src/main/resources/model.nnet";
    /** Classpath resource read by {@link #loadFromDefaultResource()}. */
    private static final String DEFAULT_MODEL_RESOURCE_PATH = "/model.nnet";
    /**
     * Multiplier applied to the output count to derive the middle {@link NeuralNetworkConfig}
     * argument (presumably the hidden-layer size — TODO confirm against NeuralNetworkConfig).
     */
    private static final int HIDDEN_NEURONS_PER_OUTPUT = 30;
    /** Snapshots are only written once training has passed this iteration. */
    private static final int SNAPSHOT_START_ITERATION = 300;
    /** After {@link #SNAPSHOT_START_ITERATION}, a snapshot is written every this many iterations. */
    private static final int SNAPSHOT_INTERVAL = 100;
    /** Value passed to {@code setMinErrorChangeIterationsLimit} on the learning rule. */
    private static final int MIN_ERROR_CHANGE_ITERATIONS_LIMIT = 50;

    /** Topology description; also carries the trained weights when the model is saved. */
    private final NeuralNetworkConfig config;
    /** The live network built from {@link #config}. */
    private final MultiLayerPerceptron neuralNetwork;

    /**
     * Error function that takes, for each sample, the largest absolute output error. A sample
     * whose largest error exceeds {@link #MISS_THRESHOLD} is counted as a full error of 1.0. The
     * total error is the mean of these per-sample values.
     *
     * <p>Static because it does not use the enclosing instance.
     */
    /* loaded from: classes2.dex */
    private static final class AvgMaxErrorFunction implements ErrorFunction {
        /** Per-sample max error above which the sample is scored as 1.0. */
        private static final double MISS_THRESHOLD = 0.45d;

        /** Number of samples accumulated since the last {@link #reset()}. */
        private int n;
        /** Sum of the (possibly clamped-to-1.0) per-sample max errors. */
        private double sum;

        @Override // org.neuroph.core.learning.error.ErrorFunction
        public void addOutputError(double[] outputError) {
            double maxAbsError = 0.0d;
            for (double error : outputError) {
                maxAbsError = Math.max(Math.abs(error), maxAbsError);
            }
            if (maxAbsError > MISS_THRESHOLD) {
                maxAbsError = 1.0d;
            }
            this.sum += maxAbsError;
            this.n++;
        }

        /** Returns the mean per-sample error, or 0.0 when no samples have been added. */
        @Override // org.neuroph.core.learning.error.ErrorFunction
        public double getTotalError() {
            return this.n > 0 ? this.sum / this.n : 0.0d;
        }

        @Override // org.neuroph.core.learning.error.ErrorFunction
        public void reset() {
            this.n = 0;
            this.sum = 0.0d;
        }
    }

    /**
     * Creates an untrained network.
     *
     * @param inputSize first argument to {@link NeuralNetworkConfig} (presumably the number of
     *     input neurons — confirm)
     * @param outputSize last argument to {@link NeuralNetworkConfig} (presumably the number of
     *     output neurons — confirm); the middle argument is {@code outputSize * 30}
     */
    public SoundRecognitionNeuralNetwork(int inputSize, int outputSize) {
        this(new NeuralNetworkConfig(inputSize, outputSize * HIDDEN_NEURONS_PER_OUTPUT, outputSize));
    }

    /** Builds the live network from an existing (possibly deserialized) config. */
    private SoundRecognitionNeuralNetwork(NeuralNetworkConfig neuralNetworkConfig) {
        this.config = neuralNetworkConfig;
        this.neuralNetwork = neuralNetworkConfig.buildNetwork();
    }

    /**
     * Deserializes a {@link NeuralNetworkConfig} from {@code inputStream} and builds a network
     * from it. The stream is owned by the caller and is not closed here.
     *
     * <p>SECURITY: this uses Java native deserialization; only pass trusted model data.
     *
     * @param inputStream stream positioned at a serialized {@link NeuralNetworkConfig}
     * @return the loaded network
     * @throws NeurophException if the stream is null, cannot be read, or does not contain a
     *     {@link NeuralNetworkConfig}
     */
    public static SoundRecognitionNeuralNetwork load(InputStream inputStream) {
        if (inputStream == null) {
            throw new NeurophException("Unable to load model from file: input stream is null");
        }
        try {
            // The ObjectInputStream wraps the caller's stream, so it is left unclosed on purpose.
            ObjectInputStream objectInputStream = new ObjectInputStream(new BufferedInputStream(inputStream));
            return new SoundRecognitionNeuralNetwork((NeuralNetworkConfig) objectInputStream.readObject());
        } catch (Exception e) {
            // Broad catch kept deliberately: callers expect every load failure
            // (I/O, ClassNotFound, ClassCast, network build) as a NeurophException.
            throw new NeurophException("Unable to load model from file", e);
        }
    }

    /**
     * Loads the model bundled as the classpath resource {@code /model.nnet}. The resource stream
     * is closed afterwards.
     *
     * @throws NeurophException if the resource is missing or cannot be read
     */
    public static SoundRecognitionNeuralNetwork loadFromDefaultResource() {
        try (InputStream resource = SoundRecognitionNeuralNetwork.class.getResourceAsStream(DEFAULT_MODEL_RESOURCE_PATH)) {
            if (resource == null) {
                throw new NeurophException("Model resource not found: " + DEFAULT_MODEL_RESOURCE_PATH);
            }
            return load(resource);
        } catch (IOException e) {
            // Only reachable from closing the resource stream.
            throw new NeurophException("Unable to close model resource " + DEFAULT_MODEL_RESOURCE_PATH, e);
        }
    }

    /**
     * Feeds {@code fArr} into the input neurons, runs a forward pass and returns the output.
     *
     * <p>Only the first {@code fArr.length} input neurons are set. If the array is shorter than
     * the input layer, the remaining neurons keep whatever input they had before. If it is
     * longer, an {@link ArrayIndexOutOfBoundsException} is thrown.
     *
     * @param fArr input feature values
     * @return the network output as returned by {@code MultiLayerPerceptron.getOutput()}
     */
    public double[] calc(float[] fArr) {
        Neuron[] inputNeurons = this.neuralNetwork.getInputNeurons();
        for (int i = 0; i < fArr.length; i++) {
            inputNeurons[i].setInput(fArr[i]);
        }
        this.neuralNetwork.calculate();
        return this.neuralNetwork.getOutput();
    }

    /** Creates an empty {@link DataSet} sized to this network's input and output counts. */
    public DataSet createDataSet() {
        return new DataSet(this.neuralNetwork.getInputsCount(), this.neuralNetwork.getOutputsCount());
    }

    /**
     * Copies the current weights into {@link #config} and serializes it to
     * {@code SoundEventRecognition/src/main/resources/model.nnet} (relative to the working
     * directory). Both streams are always closed, so repeated snapshots do not leak file handles.
     *
     * @throws NeurophException if the file cannot be opened or written
     */
    public void save() {
        try (FileOutputStream fileOutputStream = new FileOutputStream(DEFAULT_MODEL_FILE_PATH);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(new BufferedOutputStream(fileOutputStream))) {
            this.config.setWeights(this.neuralNetwork.getWeights());
            objectOutputStream.writeObject(this.config);
            objectOutputStream.flush();
        } catch (IOException e) {
            throw new NeurophException("Unable to write model to file", e);
        }
    }

    /**
     * Trains the network on {@code dataSet} with back-propagation using
     * {@link AvgMaxErrorFunction}, then saves the model.
     *
     * <p>Progress is printed to stdout on every learning event. After iteration 300, a snapshot
     * is saved every 100 iterations. The progress listener is removed once learning ends, so
     * calling this method repeatedly does not stack listeners.
     *
     * @param maxError target total error; the minimum error change is set to one tenth of it
     */
    public void train(DataSet dataSet, double maxError) {
        final BackPropagation learningRule = this.neuralNetwork.getLearningRule();
        learningRule.setMaxError(maxError);
        learningRule.setMinErrorChange(maxError / 10.0d);
        learningRule.setMinErrorChangeIterationsLimit(MIN_ERROR_CHANGE_ITERATIONS_LIMIT);
        LearningEventListener progressListener = new LearningEventListener() { // from class: ru.demax.rhythmerr.audio.recognition.nn.SoundRecognitionNeuralNetwork.1
            @Override // org.neuroph.core.events.LearningEventListener
            public void handleLearningEvent(LearningEvent learningEvent) {
                int currentIteration = learningRule.getCurrentIteration();
                System.out.printf(Locale.US, "%s: iteration %d, error %.5f%n", learningEvent.getEventType(), currentIteration, learningRule.getErrorFunction().getTotalError());
                if (currentIteration <= SNAPSHOT_START_ITERATION || currentIteration % SNAPSHOT_INTERVAL != 0) {
                    return;
                }
                System.out.print("Saving snapshot... ");
                SoundRecognitionNeuralNetwork.this.save();
                System.out.println("Done");
            }
        };
        learningRule.addListener(progressListener);
        learningRule.setErrorFunction(new AvgMaxErrorFunction());
        try {
            this.neuralNetwork.learn(dataSet);
        } finally {
            // Without this, every train() call would add another logging/snapshot listener.
            learningRule.removeListener(progressListener);
        }
        save();
    }
}
